import re
import torch
import json
import requests
from io import BytesIO
from PIL import Image
import sys
import os
from tqdm import tqdm 
from pathlib import Path

# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value already exported
# by the caller. This must run before anything initializes CUDA.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from tci_attn import LlamaAttentionWithLogits

def load_image(image_file: str, timeout: float = 30.0) -> Image.Image:
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL starting with
            ``http://`` or ``https://``.
        timeout: Seconds to wait for the HTTP response when ``image_file``
            is a URL; ignored for local paths.

    Returns:
        A new image in RGB mode, independent of any open file handle.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On other network failures or timeouts.
        OSError: If the path is missing or the bytes are not a readable image.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on HTTP errors instead of passing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # The context manager closes the underlying file; convert() returns a new,
    # fully loaded image, so the result stays valid after the close.
    with Image.open(source) as img:
        return img.convert("RGB")

def load_images(image_files: list[str]) -> list[Image.Image]:
    """Load each path/URL in ``image_files`` via ``load_image``, keeping input order."""
    return [load_image(path) for path in image_files]

def image_parser(image_arg: str, sep: str) -> list[str]:
    """Break ``image_arg`` on every occurrence of ``sep``.

    Segments are returned as-is (no whitespace stripping), and empty
    segments are kept, so ``image_parser("", ",")`` yields ``[""]``.
    """
    parts = image_arg.split(sep)
    return parts

def _infer_conv_mode(model_name: str) -> str:
    """Guess a LLaVA conversation-template key from the model name."""
    name = model_name.lower()
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def _install_tci_attention(model, args) -> None:
    """Replace self-attention in the selected decoder layers with LlamaAttentionWithLogits.

    Layer indices come from ``args.tci_layers`` (default ``(0, 1, 14, 15, 17)``),
    and ``args.alpha`` is forwarded to each replacement module.
    """
    target_layers = set(getattr(args, "tci_layers", (0, 1, 14, 15, 17)))
    for i, layer in enumerate(model.model.layers):
        if i not in target_layers:
            continue
        attn_adap = LlamaAttentionWithLogits(layer.self_attn.config, layer_idx=i, alpha=args.alpha)
        # Start the replacement from the original layer's weights.
        attn_adap.load_state_dict(layer.self_attn.state_dict())
        # Match the device/dtype of the module being replaced.
        ref_param = next(layer.self_attn.parameters())
        layer.self_attn = attn_adap.to(device=ref_param.device, dtype=ref_param.dtype)


def _resolve_image_path(image_folder: str, image_name: str) -> str:
    """Join an image file name onto a local folder or an HTTP(S) URL prefix."""
    if image_folder.startswith(("http://", "https://")):
        # os.path.join would insert "\\" on Windows, which breaks URLs.
        return image_folder.rstrip("/") + "/" + image_name.lstrip("/")
    return os.path.join(image_folder, image_name)


def _build_prompt(question: str, conv_mode: str, mm_use_im_start_end: bool) -> str:
    """Insert the image token into ``question`` and wrap it in the conversation template."""
    if mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in question:
        # Literal replacement: the placeholder is plain text, not a regex.
        qs = question.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        qs = image_token + "\n" + question
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
    return conv.get_prompt()


def eval_model(args):
    """Run LLaVA on every item of ``args.data`` and write one JSONL answer per item.

    Required attributes on ``args``:
        model_path, model_base, output_folder, output_json_file, image_folder,
        data, temperature, top_p, num_beams, max_new_tokens, and alpha (when
        tci is enabled).
    Optional attributes:
        torch_dtype (default float16), conv_mode (inferred from the model
        name when missing/empty), tci (default False), tci_layers (default
        (0, 1, 14, 15, 17)), model_id (default "LLaVA-1.5-7B-W+2").

    Each item of ``args.data`` must provide "image", "text", "category" and
    "question_id". Results are written to
    ``os.path.join(args.output_folder, args.output_json_file)``, overwriting it.
    """
    disable_torch_init()

    # 1. Load model.
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path=args.model_path,
        model_base=args.model_base,
        model_name=model_name,
        # NOTE(review): the loader may still force its own dtype internally -- confirm.
        torch_dtype=getattr(args, "torch_dtype", torch.float16),
    )

    # 2. Optionally swap selected attention modules for the TCI variant.
    if getattr(args, "tci", False):
        _install_tci_attention(model, args)

    # 3. Loop-invariant settings.
    conv_mode = getattr(args, "conv_mode", None) or _infer_conv_mode(model_name)
    mm_use_im_start_end = model.config.mm_use_im_start_end
    model_id = getattr(args, "model_id", "LLaVA-1.5-7B-W+2")
    output_file = os.path.join(args.output_folder, args.output_json_file)

    with open(output_file, "w", encoding="utf-8") as f_out:
        for data in tqdm(args.data, desc="Processing images"):
            image = load_image(_resolve_image_path(args.image_folder, data["image"]))
            # Cast pixels to the model's dtype so the vision tower sees matching inputs.
            images_tensor = process_images([image], image_processor, model.config).to(
                model.device, dtype=model.dtype
            )

            prompt = _build_prompt(data["text"], conv_mode, mm_use_im_start_end)
            input_ids = (
                tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
                .unsqueeze(0)
                .to(model.device)
            )

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=images_tensor,
                    image_sizes=[image.size],
                    do_sample=args.temperature > 0,  # temperature 0 => greedy decoding
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                )

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            result = {
                "image": data["image"],
                "text": data["text"],
                "category": data["category"],
                "question_id": data["question_id"],
                "model_id": model_id,
                "model_answer": outputs,
            }
            f_out.write(json.dumps(result) + "\n")
            # Each item costs a full generate() call; flush so partial results survive a crash.
            f_out.flush()


def main():
    """Evaluate LLaVA-1.5-7B with TCI attention on LLaVA-Bench (In-the-Wild).

    Reads the questions JSONL file, builds the run configuration and hands it
    to ``eval_model``, which writes the answers into ``output_folder``.
    """
    from types import SimpleNamespace

    model_path = "llava-v1.5-7b"
    image_folder = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
    output_folder = "tci/results/llava_bench"
    input_json_file = "tci/llava_bench_in_the_wild/questions.jsonl"
    output_json_file = "llava_bench_llava_W+2.jsonl"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)

    # JSONL input: one JSON object per non-blank line.
    with open(input_json_file, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    args = SimpleNamespace(
        model_path=model_path,
        model_base=None,
        output_json_file=output_json_file,
        conv_mode="vicuna_v1",
        image_folder=image_folder,
        output_folder=output_folder,
        data=data,
        sep=",",
        temperature=0.0,  # 0 disables sampling in eval_model (greedy decoding)
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
        torch_dtype=torch.float16,
        tci=True,  # swap selected attention layers for LlamaAttentionWithLogits
        alpha=2,
    )

    eval_model(args)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()